feat: implement rae autoencoder. #13046

Open

Ando233 wants to merge 33 commits into huggingface:main from Ando233:rae

Conversation

@Ando233 (Author) commented Jan 28, 2026

What does this PR do?

This PR adds a new representation autoencoder implementation, AutoencoderRAE, to diffusers.

- Implements diffusers.models.autoencoders.autoencoder_rae.AutoencoderRAE with a frozen pretrained vision encoder (DINOv2 / SigLIP2 / ViT-MAE) and a ViT-MAE-style decoder.
- The decoder implementation is aligned with the RAE-main GeneralDecoder parameter structure, enabling existing trained decoder checkpoints (e.g. model.pt) to be loaded without key mismatches when encoder/decoder settings are consistent.
- Adds unit/integration tests under diffusers/tests/models/autoencoders/test_models_autoencoder_rae.py.
- Registers exports so users can import directly via from diffusers import AutoencoderRAE.

Fixes #13000

Usage

import torch

from diffusers import AutoencoderRAE

# `encoder_path`, `image_size`, `patch_size`, `num_patches`, `device`, `x`,
# and `args.decoder_ckpt` are assumed to be defined elsewhere.
ae = AutoencoderRAE(
    encoder_cls="dinov2",
    encoder_name_or_path=encoder_path,
    image_size=image_size,
    encoder_input_size=image_size,
    patch_size=patch_size,
    num_patches=num_patches,
    decoder_hidden_size=1152,
    decoder_num_hidden_layers=28,
    decoder_num_attention_heads=16,
    decoder_intermediate_size=4096,
).to(device)
ae.eval()

# Load an existing trained decoder checkpoint (e.g. model.pt from RAE-main).
state = torch.load(args.decoder_ckpt, map_location="cpu")
ae.decoder.load_state_dict(state, strict=False)

with torch.no_grad():
    recon = ae(x).sample

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul sayakpaul requested a review from kashif January 30, 2026 11:31
@sayakpaul (Member)

@bytetriper if you could take a look?

@kashif (Contributor) commented Jan 30, 2026

Nice work @Ando233, checking.

@kashif (Contributor) commented Jan 30, 2026

Off the bat:

  • let's have a nice convention for the output datatype classes; have a look at the other autoencoders for the convention in diffusers
  • some of the tests might need to be marked as slow, and some paths are hard-coded

Let's sort out these things and then re-look.

@bytetriper

Agree with @kashif. Also, if possible, we can bake all the params into the config so we can enable .from_pretrained(), which is more elegant and aligns with diffusers usage. I can help convert our released checkpoints to the Hugging Face format afterwards.

@sayakpaul (Member)

@Ando233 we're happy to provide assistance if needed.

@kashif (Contributor) commented Feb 15, 2026

@Ando233 the one remaining thing is the use of use_encoder_loss, and perhaps an example real-world training script.

@kashif (Contributor) commented Feb 15, 2026

@bytetriper could you kindly try to run the conversion scripts and upload the diffusers-style weights to the Hugging Face Hub for the checkpoints you have?

@Ando233 (Author) commented Feb 17, 2026

Thank you for your efforts @kashif; let me try to implement the remaining use_encoder_loss and the real-world training script.

@kashif (Contributor) commented Feb 17, 2026

@Ando233 I added that already, so next we can wait for @bytetriper for a review and see if the weight conversion works on his end

@bytetriper

Thanks for the implementation! I just checked and weight conversion works on my end. Converted models are under https://huggingface.co/collections/nyu-visionx/rae. @kashif @Ando233 Can you check whether the converted models work on your end?

@sayakpaul (Member)

@bytetriper thanks! What would be the quickest way to validate if the implementation is correct? We can do a quick value assertion test between the original model and the converted model on the same inputs. Would you be able to do it?
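Such a value assertion could be as simple as the following sketch (plain floats stand in for the two models' flattened outputs on the same input; with real tensors, torch.allclose would play the same role, and all names here are illustrative):

```python
import math

# Sketch of the proposed parity check: `original` and `converted` stand in
# for flattened outputs of the RAE-main model and the converted diffusers
# AutoencoderRAE on the same input.
def outputs_match(original, converted, rel_tol=1e-4, abs_tol=1e-5):
    """True if both sequences agree element-wise within tolerance."""
    if len(original) != len(converted):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(original, converted)
    )

# Identical up to small numerical differences: passes.
assert outputs_match([0.12345, -1.5], [0.123451, -1.500001])
# A genuine mismatch: fails.
assert not outputs_match([0.1, 0.2], [0.1, 0.9])
```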

@sayakpaul (Member) left a comment

Left a bunch of comments. The major thing is we need to be a bit more explicit in terms of how we're defining the configs, loading encoder state dicts, etc.

I think we could aim for the following entrypoint for instantiating the AutoencoderRAE class:

AutoencoderRAE(..., encoder_type="dinov2")

Inside the implementation of AutoencoderRAE's __init__(), specifically, we can have a simple if/else block to dispatch the encoder based on encoder_type:

if encoder_type == "dinov2":
    encoder = Dinov2Encoder()
elif encoder_type == "siglip2":
    encoder = Siglip2Encoder()
...

And then, when a user does AutoencoderRAE.from_pretrained(...), the state dict should have both the encoder and decoder state dict, following how it's done in the other Autoencoder implementations of diffusers.
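As a rough sketch of that combined layout (plain dicts stand in for tensors; the encoder./decoder. key prefixes and example keys are assumptions, not the final naming):

```python
# Sketch: one state dict holding both sub-modules under prefixed keys, so a
# single from_pretrained(...) call could restore encoder and decoder at once.
def merge_state_dicts(encoder_sd, decoder_sd):
    """Combine per-submodule state dicts into a single prefixed dict."""
    merged = {f"encoder.{k}": v for k, v in encoder_sd.items()}
    merged.update({f"decoder.{k}": v for k, v in decoder_sd.items()})
    return merged

def split_state_dict(merged):
    """Recover the per-submodule state dicts from the merged layout."""
    enc = {k[len("encoder."):]: v for k, v in merged.items() if k.startswith("encoder.")}
    dec = {k[len("decoder."):]: v for k, v in merged.items() if k.startswith("decoder.")}
    return enc, dec

merged = merge_state_dicts({"embeddings.weight": 1}, {"blocks.0.attn.weight": 2})
assert "encoder.embeddings.weight" in merged
assert split_state_dict(merged) == ({"embeddings.weight": 1}, {"blocks.0.attn.weight": 2})
```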

I will also let @dg845 take a look and provide feedback.

Comment on lines 15 to 21
`AutoencoderRAE` is a representation autoencoder that combines a frozen vision encoder (DINOv2, SigLIP2, or MAE) with a ViT-MAE-style decoder.

Paper: [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690).

The model follows the standard diffusers autoencoder API:
- `encode(...)` returns an `EncoderOutput` with a `latent` tensor.
- `decode(...)` returns a `DecoderOutput` with a `sample` tensor.

Cc: @stevhliu. Could you leave suggestions on the docs?


model = AutoencoderRAE(
    encoder_cls="dinov2",
    encoder_name_or_path="facebook/dinov2-with-registers-base",

The org should be nyu-visionx.

- `encode(...)` returns an `EncoderOutput` with a `latent` tensor.
- `decode(...)` returns a `DecoderOutput` with a `sample` tensor.

## Usage

@kashif does this need updating?


For latent normalization, use `latents_mean` and `latents_std` (matching other diffusers autoencoders).

See `examples/research_projects/autoencoder_rae/train_autoencoder_rae.py` for a stage-1 style training script

What does stage-2 have? Generation?


`encoder_cls` supports `"dinov2"`, `"siglip2"`, and `"mae"`.

For latent normalization, use `latents_mean` and `latents_std` (matching other diffusers autoencoders).

We should provide an example for this.
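A minimal sketch of what such an example could show (plain lists stand in for per-channel latent tensors; only the latents_mean/latents_std names come from the doc text, the rest is the usual shift-and-scale):

```python
# Sketch: normalize latents with latents_mean / latents_std before handing
# them to a diffusion model, and invert the operation before decoding.
def normalize_latents(latents, latents_mean, latents_std):
    return [(v - latents_mean) / latents_std for v in latents]

def denormalize_latents(latents, latents_mean, latents_std):
    return [v * latents_std + latents_mean for v in latents]

z = [0.5, -1.0, 2.0]
z_norm = normalize_latents(z, latents_mean=0.5, latents_std=1.5)
z_back = denormalize_latents(z_norm, latents_mean=0.5, latents_std=1.5)
# The round trip recovers the original latents.
assert all(abs(a - b) < 1e-9 for a, b in zip(z, z_back))
```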

Comment on lines +66 to +67
self.model.layernorm.weight = None
self.model.layernorm.bias = None

These params are not used in the forward pass anyway. So, maybe it's not needed?

Comment on lines 537 to 541
from transformers import AutoImageProcessor

proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
encoder_mean = torch.tensor(proc.image_mean, dtype=torch.float32).view(1, 3, 1, 1)
encoder_std = torch.tensor(proc.image_std, dtype=torch.float32).view(1, 3, 1, 1)

This should be explicitly in the conversion script. This is an antipattern for the library.

We could do something like:
https://github.com/huggingface/diffusers/blob/a80b19218b4bd4faf2d6d8c428dcf1ae6f11e43d/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py#L1112C9-L1116C1

Then in the conversion script, make these a part of the converted state dict before loading that into the diffusers implementation. LMK if it's unclear.
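For illustration, baking the stats into the converted state dict could look roughly like this (the encoder_mean/encoder_std key names and the ImageNet-style values are assumptions; a real script would store torch tensors of shape (1, 3, 1, 1) rather than nested lists):

```python
# Sketch: in the conversion script, store the image processor's normalization
# stats in the converted state dict so the modeling code no longer needs
# AutoImageProcessor at runtime. Nested lists stand in for (1, 3, 1, 1)
# tensors; keys and values are illustrative.
def add_normalization_stats(state_dict, image_mean, image_std):
    state_dict["encoder_mean"] = [[[[m]] for m in image_mean]]
    state_dict["encoder_std"] = [[[[s]] for s in image_std]]
    return state_dict

sd = add_normalization_stats({}, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
assert sd["encoder_mean"] == [[[[0.485]], [[0.456]], [[0.406]]]]
assert len(sd["encoder_std"][0]) == 3  # one entry per channel
```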

Comment on lines 546 to 556
# Optional latent normalization (RAE-main uses mean/var)
latents_mean_tensor = _as_optional_tensor(latents_mean)
self.do_latent_normalization = latents_mean is not None or latents_std is not None
if latents_mean_tensor is not None:
    self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)
else:
    self._latents_mean = None
if latents_std_tensor is not None:
    self.register_buffer("_latents_std", latents_std_tensor, persistent=True)
else:
    self._latents_std = None

Seems like this can be removed?

if encoder_hidden_size is None:
    raise ValueError(f"Encoder '{encoder_cls}' must define `.hidden_size` attribute.")

decoder_config = SimpleNamespace(

Why do we need this?

- trainable_cls_token
"""

def __init__(self, config, num_patches: int):

We should split out the config and expand the __init__ args here. That's how it's done in diffusers.

@sayakpaul sayakpaul requested a review from dg845 February 23, 2026 03:25
@bytetriper

> @bytetriper thanks! What would be the quickest way to validate if the implementation is correct? We can do a quick value assertion test between the original model and the converted model on the same inputs. Would you be able to do it?

I tested, and the converted model produces identical output on my end up to some small numerical differences. Just want to make sure it also has the same behavior on others' ends :)

I generally agree that we should have the encoder in the checkpoint as well. I can help with the conversion afterwards.

@sayakpaul (Member)

Cool then. I will give you a heads up when the PR is ready for another look. Thank you!

@kashif (Contributor) commented Feb 23, 2026

@bytetriper I sent you some fixes to the weights; could you kindly merge them?

@bytetriper

@kashif Merged!

@kashif kashif requested a review from sayakpaul February 26, 2026 12:22
